'''
This module is responsible for creating sockets on interfaces, and for manipulating packets to get information from them. 5 constants hold the type code of a specific IEEE 802.11 element:
  IEEE_TLV_TYPE_SSID
  IEEE_TLV_TYPE_CHANNEL
  IEEE_TLV_TYPE_RSN
  IEEE_TLV_TYPE_CSA
  IEEE_TLV_TYPE_VENDOR

4 constants are bit masks that represent the bit flag position of a RadioTap element:
  IEEE80211_RADIOTAP_RATE = (1 << 2)
  IEEE80211_RADIOTAP_CHANNEL = (1 << 3)
  IEEE80211_RADIOTAP_TX_FLAGS = (1 << 15)
  IEEE80211_RADIOTAP_DATA_RETRIES = (1 << 17)
'''
import logging
logging.getLogger("scapy.runtime").setLevel(logging.ERROR)
from scapy.all import *
from log_messages import *
import socket, struct, subprocess

# Type codes (element IDs) of the IEEE 802.11 tagged elements this module looks up or builds.
# They are compared against Dot11Elt.ID in get_tlv_value() and used as the ID in construct_csa().
IEEE_TLV_TYPE_SSID    = 0
IEEE_TLV_TYPE_CHANNEL = 3
IEEE_TLV_TYPE_RSN     = 48
IEEE_TLV_TYPE_CSA     = 37
IEEE_TLV_TYPE_VENDOR  = 221

# Bit masks for the RadioTap "present" bitmap; each one selects the bit that announces
# the corresponding RadioTap field (tested against p[RadioTap].present in MitmSocket.recv).
IEEE80211_RADIOTAP_RATE = (1 << 2)
IEEE80211_RADIOTAP_CHANNEL = (1 << 3)
IEEE80211_RADIOTAP_TX_FLAGS = (1 << 15)
IEEE80211_RADIOTAP_DATA_RETRIES = (1 << 17)

class MitmSocket(L2Socket):
	'''
	Module: packet_processing
	===

	Description: L2 socket used to inject and capture 802.11 frames on one interface.
	Sent frames are tagged with the More Data flag so that their echoes can be
	recognised and dropped by recv().

	Arguments:
	  dumpfile: if set, every sent and captured frame is also written to the pcap
	            file "<dumpfile>.<iface>.pcap"
	  strict_echo_test: when True, a frame with the More Data flag is only treated
	            as an echoed injection if its RadioTap header also looks like an
	            injected one (rate present, channel absent); when False the More
	            Data flag alone is enough
	  **kwargs: forwarded unchanged to L2Socket.__init__ (e.g. iface, type)
	'''
	def __init__(self, dumpfile=None, strict_echo_test=False, **kwargs):
		super(MitmSocket, self).__init__(**kwargs)
		# self.iface is read below, so the parent constructor must run first
		self.pcap = None
		if dumpfile:
			self.pcap = PcapWriter("%s.%s.pcap" % (dumpfile, self.iface), append=False, sync=True)
		self.strict_echo_test = strict_echo_test

	def set_channel(self, channel):
		'''
		Module: packet_processing
		===
		Class: MitmSocket
		---
		Description: configure a new channel for the interface by running
		"iw <iface> set channel <channel>"

		Arguments:
		  channel: number of the new channel

		Raises:
		  subprocess.CalledProcessError: if the iw command exits with a non-zero status
		'''
		subprocess.check_output(["iw", self.iface, "set", "channel", str(channel)])

	def attach_filter(self, bpf):
		'''
		Module: packet_processing
		===
		Class: MitmSocket
		---

		Description: creates a packet filter for the interface by delegating to the
		module-level attach_filter() on the socket's input descriptor

		Arguments:
		  bpf: string containing filter rules
		'''
		log(DEBUG, "Attaching filter to %s: <%s>" % (self.iface, bpf))
		attach_filter(self.ins, bpf, self.iface)

	def send(self, p):
		'''
		Module: packet_processing
		===
		Class: MitmSocket
		---

		Description: inject new packet on network, prefixed with an empty RadioTap
		header. The packet is also written to the pcap file when one is open.

		Arguments:
		  p: packet to send; its Dot11 FCfield is modified in place (More Data set)
		'''
		# Hack: set the More Data flag so we can detect injected frames
		p[Dot11].FCfield |= 0x20
		L2Socket.send(self, RadioTap()/p)
		if self.pcap: self.pcap.write(RadioTap()/p)
		log(DEBUG, "%s: Injected frame %s" % (self.iface, dot11_to_str(p)))

	def _strip_fcs(self, p):
		'''
		Module: packet_processing
		===
		Class: MitmSocket
		---

		Description: remove the FCS from the packet, if it's present, and drop the
		RadioTap header

		Arguments:
		  p: packet from which to remove the FCS

		Returns: the Dot11 layer of p, rebuilt without its last 4 bytes when the
		RadioTap Flags field announces a trailing FCS
		'''
		# Scapy can't handle FCS field automatically
		# Bit 1 of the present bitmap announces the Flags field
		if p[RadioTap].present & 2 != 0:
			rawframe = str(p[RadioTap])
			# Skip the 8-byte fixed header, plus one extra 4-byte present word for every
			# word whose top bit (0x80 of its last byte) says another one follows
			pos = 8
			while ord(rawframe[pos - 1]) & 0x80 != 0: pos += 4

			# If the TSFT field is present, it must be 8-bytes aligned
			if p[RadioTap].present & 1 != 0:
				# Pad only when needed: an already aligned offset (e.g. 8) must not move,
				# otherwise the Flags byte is read 8 bytes too far
				pos += (-pos) % 8
				pos += 8

			# Remove FCS if present (bit 0x10 of the Flags byte)
			if ord(rawframe[pos]) & 0x10 != 0:
				return Dot11(str(p[Dot11])[:-4])

		return p[Dot11]

	def recv(self, x=MTU):
		'''
		Module: packet_processing
		===
		Class: MitmSocket
		---

		Description: listen to 802.11 packets

		Arguments:
		  x: maximum number of bytes to read, passed to L2Socket.recv

		Returns: the received Dot11 frame without RadioTap header and FCS, or None
		when nothing usable was received (no frame, not 802.11, a control frame, or
		an echo of a frame injected by send())

		Obs.: it has some rules to reject echoed injected packets
		'''
		p = L2Socket.recv(self, x)
		if p is None or Dot11 not in p: return None
		if self.pcap: self.pcap.write(p)

		# Don't care about control frames
		if p.type == 1:
			log(ALL, "%s: ignoring control frame %s" % (self.iface, dot11_to_str(p)))
			return None

		# 1. Radiotap monitor mode header is defined in ieee80211_add_tx_radiotap_header: TX_FLAGS, DATA_RETRIES, [RATE, MCS, VHT, ]
		# 2. Radiotap header for normal received frames is defined in ieee80211_add_rx_radiotap_header: FLAGS, CHANNEL, RX_FLAGS, [...]
		# 3. Beacons generated by hostapd and received on virtual interface: TX_FLAGS, DATA_RETRIES
		#
		# Conclusion: if channel flag is not present, but rate flag is included, then this could be an echoed injected frame.
		# Warning: this check fails to detect injected frames captured by the other interface (due to proximity of transmitters and capture effect)
		radiotap_possible_injection = (p[RadioTap].present & IEEE80211_RADIOTAP_CHANNEL == 0) and not (p[RadioTap].present & IEEE80211_RADIOTAP_RATE == 0)

		# Hack: ignore frames that we just injected and are echoed back by the kernel. Note that the More Data flag also
		#	allows us to detect cross-channel frames (received due to proximity of transmitters on different channel)
		# The RadioTap verdict is the local computed above; it is not an attribute of the socket
		if p[Dot11].FCfield & 0x20 != 0 and (not self.strict_echo_test or radiotap_possible_injection):
			log(DEBUG, "%s: ignoring echoed frame %s (0x%02X, present=%08X, strict=%d)" % (self.iface, dot11_to_str(p), p[Dot11].FCfield, p[RadioTap].present, radiotap_possible_injection))
			return None
		else:
			log(ALL, "%s: Received frame: %s" % (self.iface, dot11_to_str(p)))

		# Strip the FCS if present, and drop the RadioTap header
		return self._strip_fcs(p)

	def close(self):
		'''
		Module: packet_processing
		===
		Class: MitmSocket
		---
		Description: close the pcap file (if one was opened) and then the socket connection
		'''
		if self.pcap: self.pcap.close()
		super(MitmSocket, self).close()

def call_macchanger(iface, macaddr):
	'''
	Module: packet_processing
	===
	Description: run "macchanger -m <macaddr> <iface>" directly, without bringing
	the interface down before or up after the call

	Arguments:
	  iface: wifi interface
	  macaddr: MAC address

	Raises:
	  subprocess.CalledProcessError: when macchanger fails, unless its output
	  contains "It's the same MAC!!", in which case the failure is ignored
	'''
	command = ["macchanger", "-m", macaddr, iface]
	try:
		subprocess.check_output(command)
	except subprocess.CalledProcessError as ex:
		# Being asked to set the address that is already configured is not an error for us
		if "It's the same MAC!!" in ex.output:
			return
		raise

def set_mac_address(iface, macaddr):
	'''
	Module: packet_processing
	===
	Description: set the MAC address of iface to macaddr, taking the interface
	down with ifconfig before the change and bringing it back up afterwards

	Arguments:
	  iface: wifi interface
	  macaddr: MAC address

	Raises:
	  subprocess.CalledProcessError: if one of the external commands fails
	'''
	ifconfig = ["ifconfig", iface]
	subprocess.check_output(ifconfig + ["down"])
	call_macchanger(iface, macaddr)
	subprocess.check_output(ifconfig + ["up"])

def set_monitor_ack_address(iface, macaddr, sta_suffix=None):
	"""Add a virtual STA interface for ACK generation. This assumes nothing takes control of this
		 interface, meaning it remains on the current channel.

		 The virtual interface is named iface + sta_suffix ("sta" when no suffix is given), is
		 created as a managed interface, receives macaddr and is brought up."""
	if sta_suffix is None:
		sta_suffix = "sta"
	sta_iface = iface + sta_suffix

	# Remove a possibly left-over interface first; the exit status is deliberately not checked
	subprocess.call(["iw", sta_iface, "del"], stdout=subprocess.PIPE, stdin=subprocess.PIPE)

	create_cmd = ["iw", iface, "interface", "add", sta_iface, "type", "managed"]
	subprocess.check_output(create_cmd)
	call_macchanger(sta_iface, macaddr)
	subprocess.check_output(["ifconfig", sta_iface, "up"])

def xorstr(lhs, rhs):
	'''
	Module: packet_processing
	===
	Description: XOR two strings character by character, e.g. a chosen plaintext
	with a ciphertext to obtain a keystream. The result is as long as the shorter
	of the two inputs.

	Arguments:
	  lhs: plaintext
	  rhs: ciphertext
	'''
	xored = []
	for left_char, right_char in zip(lhs, rhs):
		xored.append(chr(ord(left_char) ^ ord(right_char)))
	return "".join(xored)

def dot11_get_seqnum(p):
	'''
	Module: packet_processing
	===
	Description: get the sequence number from the 802.11 MAC header

	Arguments:
	  p: packet from which to obtain the sequence number

	Returns: the SC field of the Dot11 layer with its 4 lowest bits dropped
	'''
	sequence_control = p[Dot11].SC
	return sequence_control >> 4

def dot11_get_iv(p):
	'''
	Module: packet_processing
	===
	Description: obtain the IV value from a TKIP or CCMP packet

	Arguments:
	  p: packet from which to obtain the IV

	Returns: the IV as an integer, or 0 (after logging an error) when the packet
	has no Dot11WEP layer
	'''
	if Dot11WEP not in p:
		log(ERROR, "INTERNAL ERROR: Requested IV of plaintext frame")
		return 0

	header = p[Dot11WEP]
	# The two lowest IV bytes are read the same way in both layouts
	low_bits = ord(header.iv[0]) + (ord(header.iv[1]) << 8)

	if not header.keyid & 32:
		# Bit 32 of keyid clear: the third IV byte completes a 3-byte value
		return low_bits + (ord(header.iv[2]) << 16)

	# Bit 32 of keyid set: the upper part is a little-endian 32-bit value at the start of wepdata
	(high_bits,) = struct.unpack("<I", header.wepdata[:4])
	return low_bits + (high_bits << 16)

def dot11_get_tid(p):
	'''
	Module: packet_processing
	===
	Description: if it's present, obtain the TID (Traffic ID) number from the QoS

	Arguments:
	  p: packet from which to obtain the TID

	Returns: the 4 lowest bits of the first byte of the Dot11QoS layer, or 0 when
	the packet has no such layer
	'''
	if Dot11QoS not in p:
		return 0
	first_octet = str(p[Dot11QoS])[0]
	return ord(first_octet) & 0x0F

def dot11_is_group(p):
	'''
	Module: packet_processing
	===
	Description: determine if the packet is a broadcast packet, i.e. whether its
	addr1 equals "ff:ff:ff:ff:ff:ff"

	Arguments:
	  p: packet
	'''
	# TODO: Detect if multicast bit is set in p.addr1
	broadcast_addr = "ff:ff:ff:ff:ff:ff"
	receiver = p.addr1
	return receiver == broadcast_addr

def get_eapol_msgnum(p):
	'''
	Module: packet_processing
	===
	Description: gets the number of the EAPOL message

	Arguments:
	  p: EAPOL packet

	Returns: 1, 2, 3 or 4 for a message of the 4-way handshake, 0 when the packet
	has no EAPOL layer or its Pairwise flag is not set
	'''
	pairwise_bit = 0x0008
	ack_bit      = 0x0080
	secure_bit   = 0x0200

	if EAPOL not in p:
		return 0

	# FIXME: use p[EAPOL.load] instead of str(p[EAPOL])
	raw_eapol = str(p[EAPOL])
	# Key information flags: big-endian 16-bit value at bytes 5..6 of the EAPOL layer
	(key_info,) = struct.unpack(">H", raw_eapol[5:7])

	# Only frames with the Pairwise flag are counted as 4-way handshake messages
	if not key_info & pairwise_bit:
		return 0

	if key_info & ack_bit:
		# Ack flag set: message 1 or 3, told apart by the Secure flag
		return 3 if key_info & secure_bit else 1

	# Ack flag clear: message 2 or 4, told apart by the key data length at bytes 97..98
	(key_data_len,) = struct.unpack(">H", raw_eapol[97:99])
	return 4 if key_data_len == 0 else 2

def get_eapol_replaynum(p):
	'''
	Module: packet_processing
	===
	Description: get the Replay Counter from the EAPOL packet

	Argument:
	  p: EAPOL packet

	Returns: the big-endian 64-bit integer stored at bytes 9..16 of the EAPOL layer
	'''
	# FIXME: use p[EAPOL.load] instead of str(p[EAPOL])
	raw_eapol = str(p[EAPOL])
	(replay_counter,) = struct.unpack(">Q", raw_eapol[9:17])
	return replay_counter

def set_eapol_replaynum(p, value):
	'''
	Module: packet_processing
	===
	Description: set the Replay Counter of the EAPOL packet to value

	Arguments:
	  p: EAPOL packet, modified in place
	  value: new Replay Counter, written as a big-endian 64-bit integer

	Returns: the same packet p
	'''
	eapol_layer = p[EAPOL]
	body = eapol_layer.load
	# Bytes 5..12 of the load hold the counter; everything around them is kept as is
	packed_counter = struct.pack(">Q", value)
	eapol_layer.load = body[:5] + packed_counter + body[13:]
	return p

def dot11_to_str(p):
	'''
	Module: packet_processing
	===
	Description: return a short, human readable representation of the packet for
	log messages; frames that are not specially recognised are rendered with repr()

	Arguments:
	  p: 802.11 packet
	'''
	eap_codes = {1: "Request"}
	eap_types = {1: "Identity"}
	deauth_reasons = {
		1: "Unspecified",
		2: "Prev_Auth_No_Longer_Valid/Timeout",
		3: "STA_is_leaving",
		4: "Inactivity",
		6: "Unexp_Class2_Frame",
		7: "Unexp_Class3_Frame",
		8: "Leaving",
		15: "4-way_HS_timeout",
	}

	def name_of(table, key):
		# Known values are shown by name, unknown ones as their plain string form
		return table.get(key, str(key))

	frame_type = p.type

	if frame_type == 0:
		# Management frames
		if Dot11Beacon in p:
			return "Beacon(seq=%d, TSF=%d)" % (dot11_get_seqnum(p), p[Dot11Beacon].timestamp)
		if Dot11ProbeReq in p:
			return "ProbeReq(seq=%d)" % dot11_get_seqnum(p)
		if Dot11ProbeResp in p:
			return "ProbeResp(seq=%d)" % dot11_get_seqnum(p)
		if Dot11Auth in p:
			return "Auth(seq=%d, status=%d)" % (dot11_get_seqnum(p), p[Dot11Auth].status)
		if Dot11Deauth in p:
			reason = name_of(deauth_reasons, p[Dot11Deauth].reason)
			return "Deauth(seq=%d, reason=%s)" % (dot11_get_seqnum(p), reason)
		if Dot11AssoReq in p:
			return "AssoReq(seq=%d)" % dot11_get_seqnum(p)
		if Dot11ReassoReq in p:
			return "ReassoReq(seq=%d)" % dot11_get_seqnum(p)
		if Dot11AssoResp in p:
			return "AssoResp(seq=%d, status=%d)" % (dot11_get_seqnum(p), p[Dot11AssoResp].status)
		if Dot11ReassoResp in p:
			return "ReassoResp(seq=%d, status=%d)" % (dot11_get_seqnum(p), p[Dot11ReassoResp].status)
		if Dot11Disas in p:
			return "Disas(seq=%d)" % dot11_get_seqnum(p)
		if p.subtype == 13:
			return "Action(seq=%d)" % dot11_get_seqnum(p)

	elif frame_type == 1:
		# Control frames: identified by subtype only
		if p.subtype == 9:
			return "BlockAck"
		if p.subtype == 11:
			return "RTS"
		if p.subtype == 13:
			return "Ack"

	elif frame_type == 2:
		# Data frames
		if Dot11WEP in p:
			return "EncryptedData(seq=%d, IV=%d)" % (dot11_get_seqnum(p), dot11_get_iv(p))
		if p.subtype == 4:
			return "Null(seq=%d, sleep=%d)" % (dot11_get_seqnum(p), (p.FCfield & 0x10) != 0)
		if p.subtype == 12:
			return "QoS-Null(seq=%d, sleep=%d)" % (dot11_get_seqnum(p), (p.FCfield & 0x10) != 0)
		if EAPOL in p:
			msgnum = get_eapol_msgnum(p)
			if msgnum != 0:
				return "EAPOL-Msg%d(seq=%d,replay=%d)" % (msgnum, dot11_get_seqnum(p), get_eapol_replaynum(p))
			if EAP in p:
				code = name_of(eap_codes, p[EAP].code)
				kind = name_of(eap_types, p[EAP].type)
				return "EAP-%s,%s(seq=%d)" % (code, kind, dot11_get_seqnum(p))

	# Anything that was not recognised above
	return repr(p)

def construct_csa(channel, count=1):
	'''
	Module: packet_processing
	===
	Description: constructs the CSA element

	Arguments:
	  channel: the new channel
	  count: the number of counts until the station switch

	Returns: a Dot11Elt with ID IEEE_TLV_TYPE_CSA whose info is the three bytes
	(switch mode, new channel, switch count)
	'''
	# Switch mode 1: STA should not Tx until the switch is completed
	switch_mode = 1

	# One unsigned byte per field, in the order mode / channel / count
	csa_info = struct.pack("<BBB", switch_mode, channel, count)
	return Dot11Elt(ID=IEEE_TLV_TYPE_CSA, info=csa_info)

def append_csa(p, channel, count=1):
	'''
	Module: packet_processing
	===
	Description: appends the CSA element to a copy of the packet p; p itself is
	left untouched

	Arguments:
	  p: 802.11 Beacon packet
	  channel: the new channel
	  count: the number of counts until the station switch

	Returns: the copy of p, with the CSA element set as payload of its last Dot11Elt
	'''
	clone = p.copy()

	# Walk to the last element of the Dot11Elt chain
	last_elt = clone[Dot11Elt]
	while isinstance(last_elt.payload, Dot11Elt):
		last_elt = last_elt.payload

	last_elt.payload = construct_csa(channel, count)
	return clone

def get_tlv_value(p, type):
	'''
	Module: packet_processing
	===
	Description: gets the value from the first element whose ID == `type`

	Arguments:
	  p: 802.11 packet
	  type: type of the element from which to get the value

	Returns: the info of the first matching Dot11Elt, or None when the packet has
	no Dot11Elt or none of them has the requested type
	'''
	element = p[Dot11Elt] if Dot11Elt in p else None

	# Advance until the chain ends or an element with the requested ID is reached
	while isinstance(element, Dot11Elt) and element.ID != type:
		element = element.payload

	if isinstance(element, Dot11Elt):
		return element.info
	return None